{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 0, 1, 0, 0, 0, 0],\n",
       "        [1, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 1, 0, 0, 1, 1, 0],\n",
       "        [1, 1, 0, 1, 1, 1, 0],\n",
       "        [0, 2, 2, 0, 2, 1, 1],\n",
       "        [2, 1, 1, 1, 0, 0, 1],\n",
       "        [1, 1, 0, 0, 1, 1, 1],\n",
       "        [2, 0, 0, 2, 2, 0, 1],\n",
       "        [0, 0, 1, 1, 1, 0, 1]]),\n",
       " array([[0, 0, 1, 0, 0, 0, 0],\n",
       "        [2, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 0, 0, 1, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 0, 1],\n",
       "        [2, 2, 2, 2, 2, 0, 1],\n",
       "        [2, 0, 0, 2, 2, 1, 1],\n",
       "        [0, 1, 0, 1, 0, 0, 1]]))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "#加载数据集\n",
    "def load_data():\n",
    "    with open('决策树数据.txt') as fr:\n",
    "        lines = fr.readlines()\n",
    "\n",
    "    x = np.empty((len(lines), 7), dtype=int)\n",
    "\n",
    "    for i in range(len(lines)):\n",
    "        line = lines[i].strip().split(',')\n",
    "        x[i] = line\n",
    "\n",
    "    test_x = x[10:]\n",
    "    x = x[:10]\n",
    "\n",
    "    return x, test_x\n",
    "\n",
    "\n",
    "x, test_x = load_data()\n",
    "x, test_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#计算数据集的熵,当然这个熵是针对于y来说的\n",
    "def get_entropy(_x):\n",
    "    entropy = 0\n",
    "\n",
    "    #统计y的熵就可以了\n",
    "    y = _x[:, -1]\n",
    "\n",
    "    #统计每个结果出现的次数,[5 5],表示0出现5次,1出现5次\n",
    "    bincount = np.bincount(y)  #[5 5]\n",
    "    for count in bincount:\n",
    "        if count == 0:\n",
    "            continue\n",
    "\n",
    "        #出现次数 / 总次数 = 出现概率\n",
    "        prob = count / len(_x)\n",
    "\n",
    "        #熵 = p * log(p) * -1\n",
    "        entropy -= prob * np.log2(prob)\n",
    "\n",
    "    return entropy\n",
    "\n",
    "\n",
    "get_entropy(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2754887502163468"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_gain(_x, col):\n",
    "    #列熵\n",
    "    col_entropy = 0\n",
    "\n",
    "    #根据列值,把数据分成n份\n",
    "    for value in set(_x[:, col]):\n",
    "        x_by_col_and_value = _x[_x[:, col] == value]\n",
    "\n",
    "        #这个数据子集出现的概率,很显然,等于出现次数/总次数\n",
    "        prob = len(x_by_col_and_value) / len(_x)\n",
    "\n",
    "        #求数据子集的熵\n",
    "        entropy = get_entropy(x_by_col_and_value)\n",
    "\n",
    "        #这个列的熵,等于这个式子的累计\n",
    "        col_entropy += prob * entropy\n",
    "\n",
    "    #信息增益,就是切分数据后,熵值能下降多少,这个值越大越好\n",
    "    gain = get_entropy(_x) - col_entropy\n",
    "\n",
    "    #用这个就是id3决策树,他倾向于选择可取值多的列\n",
    "    return gain\n",
    "\n",
    "\n",
    "get_gain(x, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_split_col(_x):\n",
    "    best_col = -1\n",
    "    best_gain = 0\n",
    "\n",
    "    #遍历所有的列,最后一列是y,不需要计算\n",
    "    for col in range(_x.shape[1] - 1):\n",
    "\n",
    "        #信息增益,就是切分数据后,熵值能下降多少,这个值越大越好\n",
    "        gain = get_gain(_x, col)\n",
    "\n",
    "        #信息增益最大的列,就是要被拆分的列\n",
    "        if gain > best_gain:\n",
    "            best_gain = gain\n",
    "            best_col = col\n",
    "\n",
    "    return best_col\n",
    "\n",
    "\n",
    "get_split_col(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node col=0\n",
      "Leaf y=1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(None, None)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#创建节点和叶子对象,用来构建树\n",
    "class Node():\n",
    "    def __init__(self, col):\n",
    "        self.col = col\n",
    "        self.children = {}\n",
    "\n",
    "    def __str__(self):\n",
    "        return 'Node col=%d' % self.col\n",
    "\n",
    "\n",
    "class Leaf():\n",
    "    def __init__(self, y):\n",
    "        self.y = y\n",
    "\n",
    "    def __str__(self):\n",
    "        return 'Leaf y=%d' % self.y\n",
    "\n",
    "\n",
    "print(Node(0)), print(Leaf(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n"
     ]
    }
   ],
   "source": [
    "#打印树的方法\n",
    "def print_tree(node, prefix='', subfix=''):\n",
    "    prefix += '-' * 4\n",
    "    print(prefix, node, subfix)\n",
    "    if isinstance(node, Leaf):\n",
    "        return\n",
    "    for i in node.children:\n",
    "        subfix = 'value=' + str(i)\n",
    "        print_tree(node.children[i], prefix, subfix)\n",
    "\n",
    "\n",
    "print_tree(Node(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#先在所有数据上求最大信息增益的列,结果是0\n",
    "get_split_col(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node col=0\n"
     ]
    }
   ],
   "source": [
    "#根据上面的结果,创建根节点,根节点根据列0的值来分割数据\n",
    "root = Node(0)\n",
    "print(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Node col=2 value=0\n",
      "-------- Node col=1 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "#添加子节点的方法\n",
    "def create_children(_x, parent_node):\n",
    "\n",
    "    #遍历父节点col列所有的取值\n",
    "    for split_value in np.unique(_x[:, parent_node.col]):\n",
    "\n",
    "        #首先根据父节点col列的取值分割数据\n",
    "        sub_x = _x[_x[:, parent_node.col] == split_value]\n",
    "\n",
    "        #取去重y值\n",
    "        unique_y = np.unique(sub_x[:, -1])\n",
    "\n",
    "        #如果所有的y都是一样的,说明是个叶子节点\n",
    "        if len(unique_y) == 1:\n",
    "            parent_node.children[split_value] = Leaf(unique_y[0])\n",
    "            continue\n",
    "\n",
    "        #否则,是个分支节点,计算最佳切分列\n",
    "        split_col = get_split_col(sub_x)\n",
    "\n",
    "        #添加分支节点到父节点上\n",
    "        parent_node.children[split_value] = Node(col=split_col)\n",
    "\n",
    "\n",
    "create_children(x, root)\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Node col=2 value=0\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Leaf y=1 value=1\n",
      "------------ Leaf y=1 value=2\n",
      "-------- Node col=1 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "#继续创建,0=0节点的下一层\n",
    "x_0_0 = x[x[:, 0] == 0]\n",
    "create_children(x_0_0, root.children[0])\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Node col=2 value=0\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Leaf y=1 value=1\n",
      "------------ Leaf y=1 value=2\n",
      "-------- Node col=1 value=1\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Node col=3 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "#继续创建,0=1的下一层\n",
    "x_0_1 = x[x[:, 0] == 1]\n",
    "create_children(x_0_1, root.children[1])\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Node col=2 value=0\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Leaf y=1 value=1\n",
      "------------ Leaf y=1 value=2\n",
      "-------- Node col=1 value=1\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Node col=3 value=1\n",
      "---------------- Leaf y=1 value=0\n",
      "---------------- Leaf y=0 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "#继续创建,0=1,1=1的下一层\n",
    "x_0_1_and_1_1 = x_0_1[x_0_1[:, 1] == 1]\n",
    "create_children(x_0_1_and_1_1, root.children[1].children[1])\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n",
      "-------------------------\n",
      "0.2857142857142857\n"
     ]
    }
   ],
   "source": [
    "#预测方法,测试\n",
    "def pred(_x, node):\n",
    "    col_value = _x[node.col]\n",
    "    node = node.children[col_value]\n",
    "\n",
    "    if isinstance(node, Leaf):\n",
    "        return node.y\n",
    "\n",
    "    return pred(_x, node)\n",
    "\n",
    "\n",
    "correct = 0\n",
    "for i in x:\n",
    "    if pred(i, root) == i[-1]:\n",
    "        correct += 1\n",
    "\n",
    "print(correct / len(x))\n",
    "\n",
    "print('-------------------------')\n",
    "\n",
    "correct = 0\n",
    "for i in test_x:\n",
    "    if pred(i, root) == i[-1]:\n",
    "        correct += 1\n",
    "\n",
    "print(correct / len(test_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "#序列化保存下来,后面剪枝用\n",
    "with open('tree.dump', 'wb') as fr:\n",
    "    pickle.dump(root, fr)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
